from numpyro.distributions import * 
import numpyro.distributions.util as util
import jax.numpy as np
import jax
import scalevi.utils as utils
def DiagonalNormal(mu, cov):
    """Build a diagonal-covariance Gaussian as a single event-level distribution.

    Args:
        mu: mean of each coordinate.
        cov: per-coordinate variances (the diagonal of the covariance). Their
            square roots are used as the ``Normal`` scale.

    Returns:
        An ``Independent`` wrapping an element-wise ``Normal``, with the
        rightmost batch dimension reinterpreted as the event dimension so that
        ``log_prob`` sums over it.
    """
    std = cov ** 0.5
    return Independent(Normal(mu, std), reinterpreted_batch_ndims=1)

class CustomBlockMultivariateNormal:
    """Two independent multivariate normals laid side by side in one vector.

    The first ``D_0`` coordinates of a value follow
    ``CustomMultivariateNormal(loc_0, scale_tril_0)``. The remaining coordinates
    follow ``CustomMultivariateNormal(loc_1, scale_tril_1)``. The two blocks are
    sampled with distinct subkeys, and the joint log-density is the sum of the
    two block log-densities.
    """

    def __init__(self, loc_0, scale_tril_0, loc_1, scale_tril_1):
        # Keep the raw parameters available to callers that inspect them.
        self.loc_0, self.loc_1 = loc_0, loc_1
        self.scale_tril_0, self.scale_tril_1 = scale_tril_0, scale_tril_1
        self.base_dist_0 = CustomMultivariateNormal(loc=loc_0, scale_tril=scale_tril_0)
        self.base_dist_1 = CustomMultivariateNormal(loc=loc_1, scale_tril=scale_tril_1)
        # Width of each block, taken from the trailing axis of its mean.
        self.D_0, self.D_1 = loc_0.shape[-1], loc_1.shape[-1]

    def sample(self, key, sample_shape=()):
        """Draw joint samples; block ``i`` uses the subkey ``fold_in(key, i)``."""
        key_0 = jax.random.fold_in(key, 0)
        key_1 = jax.random.fold_in(key, 1)
        head = self.base_dist_0.sample(key_0, sample_shape)
        tail = self.base_dist_1.sample(key_1, sample_shape)
        return np.concatenate([head, tail], axis=-1)

    def log_prob(self, value):
        """Sum of the block log-densities of ``value`` split at index ``D_0``."""
        head, tail = value[..., :self.D_0], value[..., self.D_0:]
        return self.base_dist_0.log_prob(head) + self.base_dist_1.log_prob(tail)

    def sample_and_log_prob(self, key, sample_shape = ()):
        """Draw joint samples and return them with their joint log-density.

        Uses the same per-block subkeys as :meth:`sample`, so for a given key
        the returned samples match what ``sample`` would produce.
        """
        key_0 = jax.random.fold_in(key, 0)
        key_1 = jax.random.fold_in(key, 1)
        head, head_lp = self.base_dist_0.sample_and_log_prob(key_0, sample_shape)
        tail, tail_lp = self.base_dist_1.sample_and_log_prob(key_1, sample_shape)
        return np.concatenate([head, tail], axis=-1), head_lp + tail_lp


class CustomMultivariateNormal(MultivariateNormal):
    """``MultivariateNormal`` with a fused reparameterised sample + log-density.

    A draw is ``z = loc + scale_tril @ eps`` with ``eps ~ N(0, I)``. The
    log-density is then evaluated directly from ``eps``.
    """

    def sample_and_log_prob(self, key, sample_shape = ()):
        """Return ``(z, log_prob)`` for reparameterised draws.

        Args:
            key: JAX PRNG key.
            sample_shape: leading shape of the draws.

        Returns:
            ``z`` of shape ``sample_shape + batch_shape + event_shape`` and
            ``-0.5 * (|eps|^2 + D * log(2*pi)) - sum(log(diag(scale_tril)))``,
            where ``D`` is the event dimension.
        """
        noise_shape = sample_shape + self.batch_shape + self.event_shape
        eps = jax.random.normal(key, shape=noise_shape)
        # Batched mat-vec L @ eps: lift eps to a column vector, then drop that axis.
        shift = np.squeeze(np.matmul(self.scale_tril, eps[..., np.newaxis]), axis=-1)
        z = self.loc + shift

        # log|det L| for a triangular factor; assumes a positive diagonal.
        diag = np.diagonal(self.scale_tril, axis1=-2, axis2=-1)
        log_det = np.log(diag).sum(-1)
        const = self.scale_tril.shape[-1] * np.log(2 * np.pi)
        # NOTE(review): utils.vtv presumably returns eps . eps over the last axis
        # (the squared norm); it is defined in scalevi.utils.
        sq_norm = utils.vtv(eps)
        return z, -0.5 * (sq_norm + const)  - log_det

class CustomNormal(Normal):
    """``Normal`` with a fused reparameterised sample + element-wise log-density."""

    def sample_and_log_prob(self, key, sample_shape = ()):
        """Return ``(z, log_prob)`` with ``z = loc + eps * scale``, ``eps ~ N(0, 1)``.

        The log-density is computed element-wise from ``eps`` as
        ``-eps**2 / 2 - log(sqrt(2*pi) * scale)``; it is not summed over any axis.
        """
        # assert is_prng_key(key)
        draw_shape = sample_shape + self.batch_shape + self.event_shape
        std_draw = jax.random.normal(key, shape=draw_shape)
        z = self.loc + std_draw * self.scale
        log_norm = np.log(np.sqrt(2 * np.pi) * self.scale)
        log_density = -0.5 * std_draw ** 2 - log_norm
        return z, log_density

class CustomIndependent(Independent):
    """``Independent`` that forwards ``sample_and_log_prob`` to its base distribution."""

    def sample_and_log_prob(self, key, sample_shape=()):
        """Sample the base distribution and reduce its log-probs to event level.

        The base log-probabilities are summed over the rightmost
        ``reinterpreted_batch_ndims`` axes via ``util.sum_rightmost``.
        """
        draws, base_lp = self.base_dist.sample_and_log_prob(key, sample_shape)
        event_lp = util.sum_rightmost(base_lp, self.reinterpreted_batch_ndims)
        return draws, event_lp

def CustomDiagonalNormal(mu, sig):
    """Build a diagonal Gaussian that supports ``sample_and_log_prob``.

    Args:
        mu: mean of each coordinate.
        sig: per-coordinate scale, passed straight to ``CustomNormal`` as its
            scale (standard deviation). Unlike ``DiagonalNormal``, this is NOT
            a variance and is not square-rooted.

    Returns:
        A ``CustomIndependent`` over ``CustomNormal`` with one batch dimension
        reinterpreted as the event dimension.
    """
    elementwise = CustomNormal(mu, sig)
    return CustomIndependent(elementwise, 1)

# class CustomMultivariateNormal(MultivariateNormal):
#     def sample_and_log_prob(self, key, sample_shape = ()):
#         eps = jax.random.normal(
#                             key, 
#                             shape=sample_shape + self.batch_shape + self.event_shape)
#         z = self.loc + np.squeeze(np.matmul(self.scale_tril, eps[..., np.newaxis]), axis=-1)

#         log_jac = np.log(np.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1)
#         normalize_term = self.scale_tril.shape[-1] * np.log(2 * np.pi)
#         M = utils.vtv(eps)
#         return z, -0.5 * (M + normalize_term)  - log_jac


#  class ParameterizedGaussian(MultivariateNormal):
#     def __init__(self, loc, scale_tril, scale_transform):
#         self.transform = scale_transform
#         super(ParameterizedGaussian, self).__init__()
#     def apply_scale_transform(self):
#         return self.transform(self.scale_tril)

#     return MultivariateNormal(
#             loc=self.loc,
#             scale_tril=self.apply_scale_transform(self.scale_tril))